{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a0cb395f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F\n",
    "from d2l import torch as d2l\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from torch.utils import data\n",
    "import wandb\n",
    "\n",
    "NUM_SAVE = 50\n",
    "net_list = \"in->256->64\"\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, in_features):\n",
    "        super().__init__()\n",
    "        self.layer1 = nn.Linear(in_features,256)\n",
    "        self.layer2 = nn.Linear(256,64)\n",
    "        self.out = nn.Linear(64,1)\n",
    "        \n",
    "    def forward(self, X):\n",
    "        X = F.relu(self.layer1(X))\n",
    "        X = F.relu(self.layer2(X))\n",
    "        return self.out(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0de8a42c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_data and test_data shape (47439, 41) (31626, 40)\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "# device = torch.device(\"cpu\")\n",
    "test_data = pd.read_csv('test.csv')\n",
    "train_data = pd.read_csv('train.csv')\n",
    "print(\"train_data and test_data shape\",train_data.shape,test_data.shape)\n",
    "\n",
    "# 去掉冗余数据\n",
    "redundant_cols = ['Address', 'Summary', 'City', 'State']\n",
    "for c in redundant_cols:\n",
    "    del test_data[c], train_data[c]\n",
    "    \n",
    "# 数据预处理\n",
    "large_vel_cols = ['Lot', 'Total interior livable area', 'Tax assessed value', 'Annual tax amount', 'Listed Price', 'Last Sold Price']\n",
    "for c in large_vel_cols:\n",
    "    train_data[c] = np.log(train_data[c]+1)\n",
    "    if c!='Sold Price':\n",
    "        test_data[c] = np.log(test_data[c]+1)\n",
    "\n",
    "# 把train和test去除id后放一起，train也要去掉label\n",
    "all_features = pd.concat((train_data.iloc[:,2:],test_data.iloc[:,1:]))\n",
    "\n",
    "# 时间数据赋日期格式\n",
    "all_features['Listed On'] = pd.to_datetime(all_features['Listed On'], format=\"%Y-%m-%d\")\n",
    "all_features['Last Sold On'] = pd.to_datetime(all_features['Last Sold On'], format=\"%Y-%m-%d\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "108a4adb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type                 174\n",
      "Heating              2660\n",
      "Cooling              911\n",
      "Parking              9913\n",
      "Bedrooms             278\n",
      "Region               1259\n",
      "Elementary School    3568\n",
      "Middle School        809\n",
      "High School          922\n",
      "Flooring             1740\n",
      "Heating features     1763\n",
      "Cooling features     596\n",
      "Appliances included  11290\n",
      "Laundry features     3031\n",
      "Parking features     9695\n"
     ]
    }
   ],
   "source": [
    "for in_object in all_features.dtypes[all_features.dtypes=='object'].index:\n",
    "    print(in_object.ljust(20),len(all_features[in_object].unique()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "cdd79ef7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before one hot code (79065, 19)\n",
      "after one hot code (79065, 470)\n",
      "train feature shape: torch.Size([47439, 470])\n",
      "test feature shape: torch.Size([31626, 470])\n",
      "train label shape: torch.Size([47439, 1])\n"
     ]
    }
   ],
   "source": [
    "# 查询数字列 ->缺失数据赋0 -> 归一化\n",
    "numeric_features = all_features.dtypes[all_features.dtypes == 'float64'].index\n",
    "all_features = all_features.fillna(method='bfill', axis=0).fillna(0)\n",
    "all_features[numeric_features] = all_features[numeric_features].apply(lambda x: (x - x.mean()) / (x.std()))\n",
    "\n",
    "features = list(numeric_features)\n",
    "features.extend(['Type','Bedrooms'])   # 加上类别数相对较少的Type, ,'Cooling features'\n",
    "all_features = all_features[features]\n",
    "\n",
    "print('before one hot code',all_features.shape)\n",
    "all_features = pd.get_dummies(all_features,dummy_na=True)\n",
    "all_features.shape\n",
    "print('after one hot code',all_features.shape)\n",
    "\n",
    "n_train = train_data.shape[0]\n",
    "train_features = torch.tensor(all_features[:n_train].values, dtype=torch.float)\n",
    "print('train feature shape:', train_features.shape)\n",
    "test_features = torch.tensor(all_features[n_train:].values, dtype=torch.float)\n",
    "print('test feature shape:', test_features.shape)\n",
    "train_labels = torch.tensor(train_data['Sold Price'].values.reshape(-1, 1), dtype=torch.float)\n",
    "print('train label shape:', train_labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e14702f4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                    Syncing run <strong><a href=\"https://wandb.ai/kinzhang/kaggle_predict/runs/3w0mpfv5\" target=\"_blank\">major-serenity-35</a></strong> to <a href=\"https://wandb.ai/kinzhang/kaggle_predict\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">docs</a>).<br/>\n",
       "\n",
       "                "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "network: MLP(\n",
      "  (layer1): Linear(in_features=470, out_features=256, bias=True)\n",
      "  (layer2): Linear(in_features=256, out_features=64, bias=True)\n",
      "  (out): Linear(in_features=64, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "criterion = nn.MSELoss()\n",
    "in_features = train_features.shape[1]\n",
    "net = MLP(in_features).to(device)\n",
    "\n",
    "def load_array(data_arrays, batch_size, is_train=True):  #@save\n",
    "    \"\"\"Construct a PyTorch data iterator.\"\"\"\n",
    "    dataset = data.TensorDataset(*data_arrays)\n",
    "    return data.DataLoader(dataset, batch_size, shuffle=is_train)\n",
    "\n",
    "def log_rmse(net, features, labels):\n",
    "    # 为了在取对数时进一步稳定该值，将小于1的值设置为1\n",
    "    clipped_preds = torch.clamp(net(features), 1, float('inf'))\n",
    "    rmse = torch.sqrt(criterion(torch.log(clipped_preds),\n",
    "                           torch.log(labels)))\n",
    "    return rmse.item()\n",
    "\n",
    "def train(net, train_features, train_labels, test_features, test_labels,\n",
    "          num_epochs, learning_rate, weight_decay, batch_size):\n",
    "    wandb.watch(net)\n",
    "    train_ls, test_ls = [], []\n",
    "    train_iter = load_array((train_features, train_labels), batch_size)\n",
    "    # 这里使用的是Adam优化算法\n",
    "    optimizer = torch.optim.Adam(net.parameters(), lr = learning_rate, weight_decay = weight_decay)\n",
    "    for epoch in tqdm(range(num_epochs)):\n",
    "        for X, y in train_iter:\n",
    "            X, y = X.to(device), y.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = net(X)\n",
    "            loss = criterion(outputs, y)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        record_loss = log_rmse(net.to('cpu'), train_features, train_labels)\n",
    "        wandb.log({'loss': record_loss,'epoch': epoch})\n",
    "        train_ls.append(record_loss)\n",
    "        if (epoch%NUM_SAVE==0 and epoch!=0) or (epoch==num_epochs-1):\n",
    "            torch.save(net.state_dict(),'checkpoint_'+str(epoch))\n",
    "            print('save checkpoints on:', epoch, 'rmse loss value is:', record_loss)\n",
    "        del X, y\n",
    "        net.to(device)\n",
    "    wandb.finish()\n",
    "    return train_ls, test_ls\n",
    "\n",
    "k, num_epochs, lr, weight_decay, batch_size = 5, 2000, 0.005, 0.05, 256\n",
    "wandb.init(project=\"kaggle_predict\",\n",
    "           config={ \"learning_rate\": lr,\n",
    "                    \"weight_decay\": weight_decay,\n",
    "                    \"batch_size\": batch_size,\n",
    "                    \"total_run\": num_epochs,\n",
    "                    \"network\": net_list}\n",
    "          )\n",
    "print(\"network:\",net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5990e4b8",
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|██                                                                              | 51/2000 [00:57<34:54,  1.07s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 50 rmse loss value is: 0.3729535639286041\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|███▉                                                                           | 101/2000 [01:53<36:36,  1.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 100 rmse loss value is: 0.24150870740413666\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|█████▉                                                                         | 151/2000 [02:53<37:27,  1.22s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 150 rmse loss value is: 0.23059451580047607\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███████▉                                                                       | 201/2000 [03:48<32:20,  1.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 200 rmse loss value is: 0.2329198122024536\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|█████████▉                                                                     | 251/2000 [04:45<33:53,  1.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 250 rmse loss value is: 0.23867167532444\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|███████████▉                                                                   | 301/2000 [05:40<32:33,  1.15s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 300 rmse loss value is: 0.2422315776348114\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█████████████▊                                                                 | 351/2000 [06:36<30:55,  1.13s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 350 rmse loss value is: 0.24655304849147797\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████████████▊                                                               | 401/2000 [07:32<29:18,  1.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 400 rmse loss value is: 0.2534756362438202\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 23%|█████████████████▊                                                             | 451/2000 [08:29<28:15,  1.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 450 rmse loss value is: 0.2810263931751251\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 25%|███████████████████▊                                                           | 501/2000 [09:22<26:28,  1.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 500 rmse loss value is: 0.3201005756855011\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|█████████████████████▊                                                         | 551/2000 [10:16<25:31,  1.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 550 rmse loss value is: 0.34142759442329407\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███████████████████████▋                                                       | 601/2000 [11:11<29:23,  1.26s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 600 rmse loss value is: 0.37291190028190613\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 33%|█████████████████████████▋                                                     | 651/2000 [12:05<24:07,  1.07s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 650 rmse loss value is: 0.41255420446395874\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 35%|███████████████████████████▋                                                   | 701/2000 [13:00<23:38,  1.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 700 rmse loss value is: 0.40421685576438904\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|█████████████████████████████▋                                                 | 751/2000 [13:53<22:00,  1.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 750 rmse loss value is: 0.4369213879108429\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|███████████████████████████████▋                                               | 801/2000 [14:45<21:13,  1.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 800 rmse loss value is: 0.4515615701675415\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 43%|█████████████████████████████████▌                                             | 851/2000 [15:41<21:13,  1.11s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 850 rmse loss value is: 0.5022755265235901\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 45%|███████████████████████████████████▌                                           | 901/2000 [16:34<19:49,  1.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 900 rmse loss value is: 0.5496918559074402\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|█████████████████████████████████████▌                                         | 951/2000 [17:28<18:53,  1.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 950 rmse loss value is: 0.5296009182929993\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|███████████████████████████████████████                                       | 1001/2000 [18:22<17:46,  1.07s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 1000 rmse loss value is: 0.5220655798912048\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 53%|████████████████████████████████████████▉                                     | 1051/2000 [19:15<16:51,  1.07s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 1050 rmse loss value is: 0.527802050113678\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 55%|██████████████████████████████████████████▉                                   | 1101/2000 [20:10<16:17,  1.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 1100 rmse loss value is: 0.5796676874160767\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|████████████████████████████████████████████▉                                 | 1151/2000 [21:05<15:24,  1.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 1150 rmse loss value is: 0.5988300442695618\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████████████████████████████████████████████▊                               | 1201/2000 [21:59<14:31,  1.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 1200 rmse loss value is: 0.580406904220581\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 63%|████████████████████████████████████████████████▊                             | 1251/2000 [22:53<13:43,  1.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 1250 rmse loss value is: 0.5514773726463318\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 65%|██████████████████████████████████████████████████▋                           | 1301/2000 [23:48<12:34,  1.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 1300 rmse loss value is: 0.5342170596122742\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 65%|██████████████████████████████████████████████████▉                           | 1305/2000 [23:53<12:43,  1.10s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-8-0a4bd784907c>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mtrain_ls\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mvalid_ls\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtrain_features\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mtrain_labels\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnum_epochs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight_decay\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      2\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      3\u001b[0m \u001b[1;31m# 使用现有训练好的net\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m \u001b[0mnet\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'cpu'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m \u001b[1;31m# 将网络应用于测试集。\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-7-387e97cb7ab5>\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(net, train_features, train_labels, test_features, test_labels, num_epochs, learning_rate, weight_decay, batch_size)\u001b[0m\n\u001b[0;32m     23\u001b[0m     \u001b[0moptimizer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptim\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mAdam\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlearning_rate\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight_decay\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mweight_decay\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     24\u001b[0m     \u001b[1;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnum_epochs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 25\u001b[1;33m         \u001b[1;32mfor\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mtrain_iter\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     26\u001b[0m             \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     27\u001b[0m             \u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Anaconda\\envs\\sim_py38\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    519\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    520\u001b[0m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 521\u001b[1;33m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    522\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    523\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[1;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Anaconda\\envs\\sim_py38\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    559\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    560\u001b[0m         \u001b[0mindex\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 561\u001b[1;33m         \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    562\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    563\u001b[0m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Anaconda\\envs\\sim_py38\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m     45\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     46\u001b[0m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 47\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcollate_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32mD:\\Anaconda\\envs\\sim_py38\\lib\\site-packages\\torch\\utils\\data\\_utils\\collate.py\u001b[0m in \u001b[0;36mdefault_collate\u001b[1;34m(batch)\u001b[0m\n\u001b[0;32m     82\u001b[0m             \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'each element in list of batch should be of equal size'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     83\u001b[0m         \u001b[0mtransposed\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mzip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatch\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 84\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mdefault_collate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msamples\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0msamples\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mtransposed\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     85\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     86\u001b[0m     \u001b[1;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdefault_collate_err_msg_format\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0melem_type\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Anaconda\\envs\\sim_py38\\lib\\site-packages\\torch\\utils\\data\\_utils\\collate.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m     82\u001b[0m             \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'each element in list of batch should be of equal size'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     83\u001b[0m         \u001b[0mtransposed\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mzip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mbatch\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 84\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mdefault_collate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msamples\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0msamples\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mtransposed\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     85\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     86\u001b[0m     \u001b[1;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdefault_collate_err_msg_format\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0melem_type\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Anaconda\\envs\\sim_py38\\lib\\site-packages\\torch\\utils\\data\\_utils\\collate.py\u001b[0m in \u001b[0;36mdefault_collate\u001b[1;34m(batch)\u001b[0m\n\u001b[0;32m     54\u001b[0m             \u001b[0mstorage\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0melem\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstorage\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_new_shared\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnumel\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     55\u001b[0m             \u001b[0mout\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0melem\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnew\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstorage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 56\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mout\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mout\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     57\u001b[0m     \u001b[1;32melif\u001b[0m \u001b[0melem_type\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__module__\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'numpy'\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0melem_type\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__name__\u001b[0m \u001b[1;33m!=\u001b[0m \u001b[1;34m'str_'\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     58\u001b[0m             \u001b[1;32mand\u001b[0m \u001b[0melem_type\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__name__\u001b[0m \u001b[1;33m!=\u001b[0m \u001b[1;34m'string_'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "train_ls, valid_ls = train(net, train_features,train_labels,None,None, num_epochs, lr, weight_decay, batch_size)\n",
    "\n",
    "# 使用现有训练好的net\n",
    "net.to('cpu')\n",
    "# 将网络应用于测试集。\n",
    "preds = net(test_features).detach().numpy()\n",
    "\n",
    "# 将其重新格式化以导出到Kaggle\n",
    "test_data['Sold Price'] = pd.Series(preds.reshape(1, -1)[0])\n",
    "submission = pd.concat([test_data['Id'], test_data['Sold Price']], axis=1)\n",
    "submission.to_csv('submission.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "571de62a",
   "metadata": {},
   "outputs": [],
   "source": [
    "net.to('cpu')\n",
    "preds = net(test_features).detach().numpy()\n",
    "# 将其重新格式化以导出到Kaggle\n",
    "test_data['Sold Price'] = pd.Series(preds.reshape(1, -1)[0])\n",
    "submission = pd.concat([test_data['Id'], test_data['Sold Price']], axis=1)\n",
    "submission.to_csv('submission.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dd76d443",
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                    Syncing run <strong><a href=\"https://wandb.ai/kinzhang/kaggle_predict/runs/365nj2uh\" target=\"_blank\">visionary-meadow-37</a></strong> to <a href=\"https://wandb.ai/kinzhang/kaggle_predict\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">docs</a>).<br/>\n",
       "\n",
       "                "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r\n",
      "  0%|                                                                                          | 0/500 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "network: MLP(\n",
      "  (layer1): Linear(in_features=470, out_features=256, bias=True)\n",
      "  (layer2): Linear(in_features=256, out_features=64, bias=True)\n",
      "  (out): Linear(in_features=64, out_features=1, bias=True)\n",
      ")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|████████▎                                                                        | 51/500 [00:55<07:58,  1.07s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 50 rmse loss value is: 0.22980350255966187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|████████████████▏                                                               | 101/500 [01:49<07:06,  1.07s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 100 rmse loss value is: 0.22977133095264435\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|████████████████████████▏                                                       | 151/500 [02:43<06:22,  1.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 150 rmse loss value is: 0.22944682836532593\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████████████████████████████████▏                                               | 201/500 [03:36<05:22,  1.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 200 rmse loss value is: 0.2292274385690689\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|████████████████████████████████████████▏                                       | 251/500 [04:30<04:29,  1.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 250 rmse loss value is: 0.22929005324840546\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|████████████████████████████████████████████████▏                               | 301/500 [05:24<03:31,  1.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 300 rmse loss value is: 0.22912709414958954\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|████████████████████████████████████████████████████████▏                       | 351/500 [06:17<02:36,  1.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 350 rmse loss value is: 0.2287970930337906\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████████████████████████████████████████████████████████████▏               | 401/500 [07:10<01:44,  1.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 400 rmse loss value is: 0.22854375839233398\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|████████████████████████████████████████████████████████████████████████▏       | 451/500 [08:03<00:52,  1.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 450 rmse loss value is: 0.22841306030750275\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [08:57<00:00,  1.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 499 rmse loss value is: 0.2283201962709427\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<br/>Waiting for W&B process to finish, PID 24456... <strong style=\"color:green\">(success).</strong>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9c69e684a144454cb7f418940952f2e0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\\r'), FloatProgress(value=1.0, max=1.0)…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style>\n",
       "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: right }\n",
       "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; width: 100% }\n",
       "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
       "    </style>\n",
       "<div class=\"wandb-row\"><div class=\"wandb-col\">\n",
       "<h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>loss</td><td>▇▇█▇█▆▇▇▆▇▆▆▆▅▅▅▅▅▅▅▄▄▅▄▄▃▂▃▃▃▃▃▃▂▃▂▂▂▁▁</td></tr></table><br/></div><div class=\"wandb-col\">\n",
       "<h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>499</td></tr><tr><td>loss</td><td>0.22832</td></tr></table>\n",
       "</div></div>\n",
       "Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)\n",
       "<br/>Synced <strong style=\"color:#cdcd00\">visionary-meadow-37</strong>: <a href=\"https://wandb.ai/kinzhang/kaggle_predict/runs/365nj2uh\" target=\"_blank\">https://wandb.ai/kinzhang/kaggle_predict/runs/365nj2uh</a><br/>\n",
       "Find logs at: <code>.\\wandb\\run-20211222_222917-365nj2uh\\logs</code><br/>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 读取已有 继续进行训练\n",
    "k, num_epochs, lr, weight_decay, batch_size = 5, 500, 0.0005, 0.08, 256\n",
    "wandb.init(project=\"kaggle_predict\",\n",
    "           config={ \"learning_rate\": lr,\n",
    "                    \"weight_decay\": weight_decay,\n",
    "                    \"batch_size\": batch_size,\n",
    "                    \"total_run\": num_epochs,\n",
    "                    \"network\": net_list}\n",
    "          )\n",
    "net.load_state_dict(torch.load('checkpoint_19676'))\n",
    "print(\"network:\",net)\n",
    "net.to(device)\n",
    "train_ls, valid_ls = train(net, train_features,train_labels,None,None, num_epochs, lr, weight_decay, batch_size)\n",
    "net.to('cpu')\n",
    "preds = net(test_features).detach().numpy()\n",
    "# 将其重新格式化以导出到Kaggle\n",
    "test_data['Sold Price'] = pd.Series(preds.reshape(1, -1)[0])\n",
    "submission = pd.concat([test_data['Id'], test_data['Sold Price']], axis=1)\n",
    "submission.to_csv('submission.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd9fb19a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 读取网络参数应用于测试集\n",
    "net = []\n",
    "net = MLP(test_features.shape[1])\n",
    "net.load_state_dict(torch.load('checkpoint_250'))\n",
    "net.to('cpu')\n",
    "preds = net(test_features).detach().numpy()\n",
    "# 将其重新格式化以导出到Kaggle\n",
    "test_data['Sold Price'] = pd.Series(preds.reshape(1, -1)[0])\n",
    "submission = pd.concat([test_data['Id'], test_data['Sold Price']], axis=1)\n",
    "submission.to_csv('submission.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8049d2e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(all_features['Type'].unique()))\n",
    "print(len(all_features['Heating'].unique()))\n",
    "print(len(all_features['Cooling'].unique()))\n",
    "print(len(all_features['Parking'].unique()))\n",
    "print(len(all_features['Bedrooms'].unique()))\n",
    "print(len(all_features['Region'].unique()))\n",
    "print(len(all_features['Elementary School'].unique()))\n",
    "print(len(all_features['Middle School'].unique()))\n",
    "print(len(all_features['High School'].unique()))\n",
    "print(len(all_features['Flooring'].unique()))\n",
    "print(len(all_features['Heating features'].unique()))\n",
    "print(len(all_features['Cooling features'].unique()))\n",
    "print(len(all_features['Appliances included'].unique()))\n",
    "print(len(all_features['Laundry features'].unique()))\n",
    "print(len(all_features['Parking features'].unique()))\n",
    "print(len(all_features['City'].unique()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sim_py38",
   "language": "python",
   "name": "sim_py38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
